Add warm-up functionality with tensor to trajectory helper functions #224
base: master
Conversation
- Add function to generate trajectories from states and actions tensors
- Add function to crudely warm up a GFN (early stopping or other tricks not included)
- …ging since every other GFN loss method returns tensor
Thank you for the PR. Could you please elaborate a little more on it? What use case are you targeting? Where do you use the new functions? Is there a way to test them and see their effects in the repo? Thanks!
Hi Salem! Yes, sorry, I contacted Joseph via Slack prior to the PR, but I should've given more detail on here. These functions are provided as a means to generate warm-up trajectories from external state-action tensors (e.g., expert knowledge, or another algorithm's output). My rationale for PR'ing these simple functions is that I found the whole process to be non-trivial when looking at the sources/docs (namely, watch for the …).

```python
def states_actions_tns_to_traj(
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
) -> Trajectories:
```

is a utility function that maps state tensors and action tensors to a `Trajectories` container.

```python
def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs=True,
):
```

is a training loop over a fixed replay buffer, but does not assume that log-probs were pre-computed. I can write some unit tests for the new functions. If you have any other feedback, send it my way so that we can implement it and follow your philosophy more closely.

Edit: I clarified why the warm-up function was important to this PR.
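To make the intended usage concrete, here is a minimal, self-contained sketch of the warm-up idea in plain PyTorch. The names here (`warm_up_sketch`, the list-based buffer, the MSE stand-in for the GFN loss) are illustrative assumptions, not torchgfn's API.

```python
import torch

def warm_up_sketch(replay_buf, optimizer, loss_fn, n_steps, batch_size):
    """Train on batches drawn from a fixed buffer; no early stopping."""
    losses = []
    for _ in range(n_steps):
        # Sample a batch of indices uniformly from the fixed buffer.
        idx = torch.randint(len(replay_buf), (batch_size,))
        batch = torch.stack([replay_buf[i] for i in idx])
        optimizer.zero_grad()
        loss = loss_fn(batch)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

# Toy usage: fit a scalar parameter to the buffer mean (4.5 here).
torch.manual_seed(0)
buf = [torch.tensor([float(i)]) for i in range(10)]
theta = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([theta], lr=0.05)
losses = warm_up_sketch(buf, opt, lambda b: ((b - theta) ** 2).mean(),
                        n_steps=200, batch_size=4)
```

In the real `warm_up`, the buffer holds trajectories and the loss comes from the GFN, but the control flow is the same: a fixed number of optimizer steps over a fixed buffer.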
> Thank you for the PR
Possible docstring to add:

For `warm_up`, a docstring would be appreciated. I am not sure why …
`src/gfn/utils/training.py` (outdated)

```python
for epoch in t:
    training_trajs = replay_buf.sample(batch_size)
    optimizer.zero_grad()
    if isinstance(gfn, TBGFlowNet):
```
With #231, this could be changed to a cleaner test (if it's a `PFBasedGFlowNet`).
Yes! Seeing your commit, I think this would be cleaner.
- Add docstrings
- Add input validation (as proposed by saleml)
- Add `PFBasedGFlowNet` verification instead of only TBGFNs (needs merge of GFNOrg#231)
Hey, first I want to apologize for taking so long to review this. I hit a bit of a lull over Dec / early Jan and have been playing catchup.
This is a really nice PR, and a feature I'd be excited to use myself in some of the applications I've been looking at. My only request revolves around the use of the dummy `log_probs`: if our library is working properly, it should function as intended using `log_probs=None`, and if not, we should fix the downstream elements if they're misbehaving, because this is the intended use of the `Trajectories` container.
Awesome contribution, thank you very much!
```python
# WARNING: This is sketchy. Create dummy values to avoid indexing / batch shape errors.
# WARNING: Assumes gfn.loss() uses recalculate_all_logprobs=True (thus only PFBasedGFlowNet are supported right now)!!
# WARNING: To reviewers: Can we bypass needing to define this?
log_probs = torch.full(size=(len(actions), 1), fill_value=0, dtype=torch.float)
```
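For context on what the dummy tensor actually is: `torch.full` with `fill_value=0` is just a float zeros tensor, i.e. a log-probability of 0 (probability 1) per action, which satisfies shape checks but carries no information. A small stand-alone check (using a hypothetical action count in place of `len(actions)`):

```python
import torch

n_actions = 5  # stand-in for len(actions)
dummy = torch.full(size=(n_actions, 1), fill_value=0, dtype=torch.float)
# Identical to torch.zeros(n_actions, 1): every entry is log(1) = 0.
```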
`log_probs` can be `None`, which will trigger a recalculation downstream (or, if it doesn't, we should fix that).
In other words, you can remove lines 129–132.
Ok! I will do some testing this week to see if it behaves as I expect before removing.
```python
actions,
log_rewards=log_rewards,
when_is_done=when_is_done,
log_probs=log_probs,
```
Suggested change:

```python
log_probs=None
```
```python
    env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns
]

# stack is a class method, so actions[0] is just to access a class instance and is not particularly relevant
```
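Two details in this hunk are worth unpacking: `a.unsqueeze(0).unsqueeze(0)` turns each scalar action into a `(1, 1)` tensor (adding batch and step dimensions), and `stack` being a classmethod means `actions[0].stack(...)` merely routes through to the class. A toy illustration with a hypothetical minimal `Actions` class (not torchgfn's):

```python
import torch

# A scalar action gains two leading singleton dims: shape () -> (1, 1).
a = torch.tensor(3)
assert a.unsqueeze(0).unsqueeze(0).shape == (1, 1)

class Actions:
    """Hypothetical container mimicking the stack-as-classmethod pattern."""
    def __init__(self, tensor):
        self.tensor = tensor

    @classmethod
    def stack(cls, items):
        # cls, not self: the instance used to call this is irrelevant.
        return cls(torch.stack([it.tensor for it in items]))

actions = [Actions(torch.tensor([i])) for i in range(3)]
via_instance = actions[0].stack(actions)  # same as Actions.stack(actions)
via_class = Actions.stack(actions)
```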
Really appreciate this comment.
I'm hoping to get your insight on how to make this better. Some parts of the code are sketchy and have been highlighted with a WARNING tag in the comments.